{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "534daf7f-4b6b-4357-9a38-9117f72ce9b4",
   "metadata": {},
   "source": [
    "# Step 1: Minimal Octo Inference Example\n",
    "\n",
    "This notebook demonstrates how to load a pre-trained / finetuned Octo checkpoint, run inference on some images, and compare the outputs to the true actions.\n",
    "\n",
    "First, let's start with a minimal example!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae44461",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run this block if you're using Colab\n",
    "\n",
    "# Download repo and make it the working directory\n",
    "!git clone https://github.com/octo-models/octo.git\n",
    "%cd octo\n",
    "# Install repo and its requirements; `%pip` (rather than `!pip`/`!pip3`) installs into the\n",
    "# environment of the running kernel\n",
    "%pip install -e .\n",
    "%pip install -r requirements.txt\n",
    "%pip install --upgrade \"jax[cuda11_pip]==0.4.20\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
    "# Pin numpy to fix the Colab error `AttributeError: module 'numpy' has no attribute '_no_nep50_warning'`;\n",
    "# if the error still shows, reload the runtime so that the pinned version is picked up\n",
    "%pip install numpy==1.21.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7229ce10",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Set TOKENIZERS_PARALLELISM to 'false' in this process' environment before the model is loaded.\n",
    "os.environ.update(TOKENIZERS_PARALLELISM=\"false\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83d34283",
   "metadata": {},
   "outputs": [],
   "source": [
    "from octo.model.octo_model import OctoModel\n",
    "\n",
    "# Checkpoint to load, given here as a HuggingFace (`hf://`) path.\n",
    "CHECKPOINT_PATH = \"hf://rail-berkeley/octo-small-1.5\"\n",
    "\n",
    "model = OctoModel.load_pretrained(CHECKPOINT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15fca0dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# download one example BridgeV2 image\n",
    "IMAGE_URL = \"https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg\"\n",
    "# a timeout keeps the cell from hanging forever if the server does not respond\n",
    "response = requests.get(IMAGE_URL, stream=True, timeout=30)\n",
    "# fail here with a clear HTTP error instead of an opaque image-decoding error below\n",
    "response.raise_for_status()\n",
    "\n",
    "# decode the streamed bytes, resize to 256x256 and convert to a numpy array\n",
    "img = np.array(Image.open(response.raw).resize((256, 256)))\n",
    "plt.imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e669650f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create obs & task dict, run inference\n",
    "import jax\n",
    "\n",
    "# add batch + time horizon 1\n",
    "# (bound to a new name so that `img` is left untouched and re-running this cell\n",
    "# does not keep prepending dimensions)\n",
    "input_image = img[np.newaxis, np.newaxis, ...]\n",
    "observation = {\"image_primary\": input_image, \"timestep_pad_mask\": np.array([[True]])}\n",
    "task = model.create_tasks(texts=[\"pick up the fork\"])\n",
    "action = model.sample_actions(\n",
    "    observation,\n",
    "    task,\n",
    "    unnormalization_statistics=model.dataset_statistics[\"bridge_dataset\"][\"action\"],\n",
    "    rng=jax.random.PRNGKey(0),\n",
    ")\n",
    "print(action)   # [batch, action_chunk, action_dim]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b2be0d1f",
   "metadata": {},
   "source": [
    "# Step 2: Run Inference on Full Trajectories\n",
    "\n",
    "That was easy! Now let's try to run inference across a whole trajectory and visualize the results!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a51eb166",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install mediapy (video display) and OpenCV (image resizing) for visualization;\n",
    "# `%pip` installs into the environment of the running kernel\n",
    "%pip install mediapy opencv-python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b0f7fd1-5b43-480f-b00f-766248d7f9af",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import jax\n",
    "import tensorflow_datasets as tfds\n",
    "import tqdm\n",
    "import mediapy\n",
    "import numpy as np"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b79053f4-316f-4d2d-81bd-e6e04cfa81bf",
   "metadata": {},
   "source": [
    "## Load Model Checkpoint\n",
    "First, we will load the pre-trained checkpoint using the `load_pretrained()` function. You can specify the path to a checkpoint directory or a HuggingFace path.\n",
    "\n",
    "Below, we are loading directly from HuggingFace.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42c04953-869d-48a8-a2df-e601324e97e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from octo.model.octo_model import OctoModel\n",
    "\n",
    "# HuggingFace path of the pre-trained checkpoint; a checkpoint directory can be given instead.\n",
    "PRETRAINED_CHECKPOINT = \"hf://rail-berkeley/octo-small-1.5\"\n",
    "\n",
    "model = OctoModel.load_pretrained(PRETRAINED_CHECKPOINT)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c298ac8f-da06-41d5-a4a5-145c3080231e",
   "metadata": {},
   "source": [
    "## Load Data\n",
    "Next, we will load a trajectory from the Bridge dataset for testing the model. We will use the publicly available copy in the Open X-Embodiment dataset bucket."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "392bd127",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the RLDS dataset builder from the bucket directory and keep only the first training episode.\n",
    "BRIDGE_BUILDER_DIR = 'gs://gresearch/robotics/bridge/0.1.0/'\n",
    "builder = tfds.builder_from_directory(builder_dir=BRIDGE_BUILDER_DIR)\n",
    "ds = builder.as_dataset(split='train[:1]')\n",
    "\n",
    "# Sample the episode and materialize its steps as a list.\n",
    "episode = next(iter(ds))\n",
    "steps = list(episode['steps'])\n",
    "\n",
    "# Resize every frame to 256x256 (default third-person cam resolution).\n",
    "images = []\n",
    "for episode_step in steps:\n",
    "    frame = np.array(episode_step['observation']['image'])\n",
    "    images.append(cv2.resize(frame, (256, 256)))\n",
    "\n",
    "# The last frame is used as the goal image; the language instruction is read from the first step.\n",
    "goal_image = images[-1]\n",
    "first_observation = steps[0]['observation']\n",
    "language_instruction = first_observation['natural_language_instruction'].numpy().decode()\n",
    "\n",
    "# Show the instruction and play back the episode.\n",
    "print(f'Instruction: {language_instruction}')\n",
    "mediapy.show_video(images, fps=10)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b37ffca5",
   "metadata": {},
   "source": [
    "## Run Inference\n",
    "\n",
    "Next, we will run inference over the images in the episode using the loaded model. \n",
    "Below we demonstrate how to set up the task for both goal-conditioned and language-conditioned inference.\n",
    "Note that we need to feed inputs of the correct temporal window size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ad64434",
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of consecutive frames fed to the model at each inference step\n",
    "WINDOW_SIZE = 2\n",
    "\n",
    "# create `task` dict -- both conditioning modes are built here for illustration\n",
    "goal_task = model.create_tasks(goals={\"image_primary\": goal_image[None]})   # for goal-conditioned\n",
    "language_task = model.create_tasks(texts=[language_instruction])            # for language conditioned\n",
    "\n",
    "# select the conditioning used by the inference loop: language conditioning is in effect\n",
    "# (assign `goal_task` instead to run goal-conditioned inference)\n",
    "task = language_task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74d6b20f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run inference loop, this model only uses 3rd person image observations for bridge\n",
    "# collect predicted and true actions\n",
    "pred_actions, true_actions = [], []\n",
    "\n",
    "# loop-invariant inputs: the same statistics and the same PRNG key are used at every step\n",
    "action_stats = model.dataset_statistics[\"bridge_dataset\"][\"action\"]\n",
    "rng = jax.random.PRNGKey(0)\n",
    "\n",
    "for step in tqdm.trange(len(images) - (WINDOW_SIZE - 1)):\n",
    "    # stack WINDOW_SIZE consecutive frames and add a leading batch dimension\n",
    "    input_images = np.stack(images[step:step+WINDOW_SIZE])[None]\n",
    "    observation = {\n",
    "        'image_primary': input_images,\n",
    "        'timestep_pad_mask': np.full((1, input_images.shape[1]), True, dtype=bool)\n",
    "    }\n",
    "\n",
    "    # the dataset statistics are passed so that the sampled actions get unnormalized\n",
    "    # NOTE(review): inferred from the `unnormalization_statistics` argument name -- confirm against\n",
    "    # `OctoModel.sample_actions`\n",
    "    actions = model.sample_actions(\n",
    "        observation,\n",
    "        task,\n",
    "        unnormalization_statistics=action_stats,\n",
    "        rng=rng,\n",
    "    )\n",
    "    actions = actions[0] # remove batch dim\n",
    "\n",
    "    pred_actions.append(actions)\n",
    "    # the ground-truth action is read at the last frame of the input window\n",
    "    final_window_step = step + WINDOW_SIZE - 1\n",
    "    true_action = steps[final_window_step]['action']\n",
    "    true_actions.append(np.concatenate(\n",
    "        (\n",
    "            true_action['world_vector'],\n",
    "            true_action['rotation_delta'],\n",
    "            np.array(true_action['open_gripper']).astype(np.float32)[None]\n",
    "        ), axis=-1\n",
    "    ))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "12a5e3f7",
   "metadata": {},
   "source": [
    "## Visualize predictions and ground-truth actions\n",
    "\n",
    "Finally, we will visualize the predicted actions in comparison to the ground-truth actions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a79775d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "ACTION_DIM_LABELS = ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']\n",
    "\n",
    "# build image strip to show above actions (every 3rd frame, laid out side by side)\n",
    "img_strip = np.concatenate(np.array(images[::3]), axis=1)\n",
    "\n",
    "# set up plt figure: one wide image row on top, one panel per action dimension below\n",
    "figure_layout = [\n",
    "    ['image'] * len(ACTION_DIM_LABELS),\n",
    "    ACTION_DIM_LABELS\n",
    "]\n",
    "plt.rcParams.update({'font.size': 12})\n",
    "fig, axs = plt.subplot_mosaic(figure_layout)\n",
    "fig.set_size_inches([45, 10])\n",
    "\n",
    "# plot actions\n",
    "# Convert the collected lists to arrays under new names and without `.squeeze()`: squeezing drops\n",
    "# every length-1 axis (e.g. a one-step episode or an action chunk of length 1), which would break\n",
    "# the indexing below; new names also leave the lists from the inference loop untouched.\n",
    "pred_action_array = np.array(pred_actions)   # [time, action_chunk, action_dim]\n",
    "true_action_array = np.array(true_actions)   # [time, action_dim]\n",
    "for action_dim, action_label in enumerate(ACTION_DIM_LABELS):\n",
    "    # each prediction is a chunk of actions; in this example we just take the first action for simplicity\n",
    "    axs[action_label].plot(pred_action_array[:, 0, action_dim], label='predicted action')\n",
    "    axs[action_label].plot(true_action_array[:, action_dim], label='ground truth')\n",
    "    axs[action_label].set_title(action_label)\n",
    "    axs[action_label].set_xlabel('Time in one episode')\n",
    "\n",
    "axs['image'].imshow(img_strip)\n",
    "axs['image'].set_xlabel('Time in one episode (subsampled)')\n",
    "plt.legend()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
